import csv
import json

import matplotlib.pyplot as plt
import numpy as np

from utils import flatten, average

# Question metadata keyed by question id. Each entry is read for its "frame"
# label and its "opposite frame" partner id ("null" when it has no partner).
with open('Datasets/individual-questions.json', "r", encoding="utf8") as questions_file:
    dataset_questions = json.load(questions_file)
keys = list(dataset_questions.keys())
frames = [dataset_questions[key]["frame"] for key in keys]
index_to_keys = dict(enumerate(keys))
key_to_index = {key: i for i, key in enumerate(keys)}

# Group every question with its opposite-frame partner exactly once.
# A key whose frame is "gain" (or that has no partner) is placed first; for any
# other key its partner is placed first (presumably the gain-frame member).
# Unpaired questions become (key, "null").
pairs = []
pair_names = []
seen_pairs = set()  # mirrors `pairs` for O(1) duplicate checks

for key in keys:
    pair = dataset_questions[key]["opposite frame"]
    if (key, pair) in seen_pairs or (pair, key) in seen_pairs:
        continue
    # Normalised like the frame checks elsewhere in this file.
    is_gain = dataset_questions[key]["frame"].lower().strip() == "gain"
    first, second = (key, pair) if is_gain or pair == "null" else (pair, key)
    pairs.append((first, second))
    seen_pairs.add((first, second))
    # One separator for both orderings so the tick labels are consistent.
    pair_names.append(key if pair == "null" else "{} & {}".format(first, second))

print(len(pairs), pairs)

def figure5():
    """Plot the per-question ``result`` values as a bar chart coloured by frame.

    Bars are green for gain-frame questions, red for loss-frame questions and
    pink for any other frame label. Before plotting, prints the bar colours and
    the mean result overall, per frame, and over gain+loss questions combined
    (computed with ``utils.average``).
    """
    frame_colours = {"gain": "#B8E793", "loss": "#E79893"}
    other_colour = "#E793E2"
    # Normalise frame labels once so colouring and per-frame averages agree.
    normalised_frames = [frame.lower().strip() for frame in frames]
    colours = [frame_colours.get(frame, other_colour) for frame in normalised_frames]
    print(colours)
    # Per-question counts recorded from earlier runs.
    # NOTE(review): their exact relation to `result` below is not shown here.
    # {0: 77, 1: 71, 2: 82, 3: 56, 4: 65, 5: 71, 6: 76, 7: 77, 8: 80, 9: 59, 10: 59, 11: 63, 12: 87, 13: 61, 14: 82, 15: 95, 16: 48, 17: 90, 18: 86, 19: 82, 20: 55, 21: 49, 22: 84, 23: 66, 24: 86, 25: 62, 26: 89, 27: 88, 28: 95, 29: 55, 30: 68, 31: 72, 32: 88, 33: 52, 34: 84, 35: 76, 36: 89, 37: 94}
    # {0: 74, 1: 76, 2: 80, 3: 58, 4: 60, 5: 75, 6: 74, 7: 77, 8: 82, 9: 53, 10: 66, 11: 50, 12: 90, 13: 61, 14: 81, 15: 97, 16: 46, 17: 89, 18: 87, 19: 79, 20: 56, 21: 45, 22: 82, 23: 61, 24: 88, 25: 63, 26: 85, 27: 86, 28: 93, 29: 58, 30: 78, 31: 58, 32: 82, 33: 51, 34: 90, 35: 76, 36: 94, 37: 97}
    result = [0.6563636363636364, 0.5967768595041323, 0.6976859504132231, 0.4828099173553719, 0.5671900826446281, 0.5967768595041323, 0.628099173553719, 0.6463636363636364, 0.6711570247933884, 0.48760330578512395, 0.48760330578512395, 0.5206611570247934, 0.73900826446281, 0.5141322314049587, 0.6976859504132231, 0.7951239669421488, 0.39669421487603307, 0.763801652892562, 0.7307438016528925, 0.6976859504132231, 0.45454545454545453, 0.4249586776859504, 0.6942148760330579, 0.5454545454545454, 0.7307438016528925, 0.552396694214876, 0.7555371900826446, 0.7272727272727273, 0.7951239669421488, 0.45454545454545453, 0.5719834710743802, 0.5950413223140496, 0.7472727272727273, 0.4297520661157025, 0.6942148760330579, 0.648099173553719, 0.7555371900826446, 0.7868595041322314]

    def results_for(frame_name):
        # Results of the questions whose normalised frame equals `frame_name`.
        return [value for frame, value in zip(normalised_frames, result) if frame == frame_name]

    gain_results = results_for("gain")
    loss_results = results_for("loss")
    null_results = results_for("null")
    print(average(result), average(gain_results), average(loss_results), average(null_results), average(gain_results + loss_results))

    plt.bar(keys, result, color=colours)
    plt.xticks(rotation=90)
    legend_colours = [frame_colours["gain"], frame_colours["loss"], other_colour]
    handles = [plt.Rectangle((0, 0), 1, 1, color=col) for col in legend_colours]
    plt.legend(handles, ["Gain Frame", "Loss Frame", "Null Frame"], bbox_to_anchor=(0.92, 1.15), ncol=3)
    plt.show()


def figurev1():
    """Plot, per question, the share of respondents choosing option 0 and 1.

    Reads ``Datasets/individual.csv``, skips the header row, and treats every
    column from the fourth onward as one question's answer. Draws one stacked
    bar per question (choice-0 share in green, choice-1 share in red on top),
    normalised by the number of response rows.

    Raises:
        ValueError: if the CSV contains no response rows.
    """
    with open("Datasets/individual.csv", "r", encoding="utf8", newline="") as csv_file:
        # Skip blank lines so a stray empty row cannot produce a ragged matrix.
        rows = [row for row in csv.reader(csv_file) if row]
    # rows[0] is the header; the first three columns are not answers.
    answers = [[int(cell) for cell in row[3:]] for row in rows[1:]]
    if not answers:
        raise ValueError("Datasets/individual.csv contains no response rows")

    n_respondents = len(answers)
    choice_1_count = np.sum(answers, axis=0)
    # Assumes every answer is 0 or 1, so "chose 0" == "did not choose 1".
    choice_0_count = n_respondents - choice_1_count
    # Divide by the real row count (was a hard-coded 121).
    choice_0_share = choice_0_count / n_respondents
    choice_1_share = choice_1_count / n_respondents

    width = 0.45
    plt.bar(keys, choice_0_share, width, color='#74E474')
    plt.bar(keys, choice_1_share, width, bottom=choice_0_share, color="#E47474")
    plt.xticks(rotation=90)
    plt.show()

def figure4():
    """Plot choice shares side by side for every question pair in ``pairs``.

    Reads ``Datasets/individual.csv`` (header row skipped, answers from the
    fourth column onward). For each (first, second) entry of the module-level
    ``pairs`` list, draws two stacked bars: the left for ``first`` and the
    right for ``second``. Each bar shows the share of respondents choosing
    option 0 (green, bottom) and option 1 (red, top). Tick labels come from
    ``pair_names``.

    Raises:
        ValueError: if the CSV contains no response rows.
    """
    with open("Datasets/individual.csv", "r", encoding="utf8", newline="") as csv_file:
        # Skip blank lines so a stray empty row cannot produce a ragged matrix.
        rows = [row for row in csv.reader(csv_file) if row]
    # rows[0] is the header; the first three columns are not answers.
    answers = [[int(cell) for cell in row[3:]] for row in rows[1:]]
    if not answers:
        raise ValueError("Datasets/individual.csv contains no response rows")

    n_respondents = len(answers)
    choice_1_count = np.sum(answers, axis=0)
    # Assumes every answer is 0 or 1, so "chose 0" == "did not choose 1".
    choice_0_count = n_respondents - choice_1_count

    def count_for(counts, key):
        # A key that is not a question id (e.g. the "null" partner of an
        # unpaired question) has no column and is drawn as an empty bar.
        # NOTE(review): assumes CSV answer column i corresponds to keys[i].
        return counts[key_to_index[key]] if key in key_to_index else 0

    first_keys = [first for first, _ in pairs]
    second_keys = [second for _, second in pairs]
    # Shares of respondents (was a hard-coded divisor of 121).
    p0 = np.array([count_for(choice_0_count, key) for key in first_keys]) / n_respondents
    p1 = np.array([count_for(choice_0_count, key) for key in second_keys]) / n_respondents
    n0 = np.array([count_for(choice_1_count, key) for key in first_keys]) / n_respondents
    n1 = np.array([count_for(choice_1_count, key) for key in second_keys]) / n_respondents

    _, ax = plt.subplots(figsize=(17, 8))
    width = 0.35
    positions = np.arange(len(pairs))

    # Left bar of each group: first key of the pair; right bar: second key.
    ax.bar(positions - width / 2, p0, width, color='#74E474', edgecolor='black')
    ax.bar(positions - width / 2, n0, width, bottom=p0, color='#E47474', edgecolor='black')
    ax.bar(positions + width / 2, p1, width, color='#74E474', edgecolor='black')
    ax.bar(positions + width / 2, n1, width, bottom=p1, color='#E47474', edgecolor='black')

    ax.set_xticks(positions)
    ax.set_xticklabels(pair_names)
    plt.xticks(rotation=90)
    plt.show()

# Draw figure 5 only when run as a script, not when this module is imported.
if __name__ == "__main__":
    figure5()


